{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "collapsed": true,
    "pycharm": {
     "name": "#%% md\n"
    }
   },
   "source": [
    "## 使用nn.Transformer和 torchtext建模\n",
    "\n",
    "> PyTorch 1.2 版本后，包括一个标准转换器模块\n",
    "> [学习来源](https://pytorch.org/tutorials/beginner/transformer_tutorial.html)\n",
    "> 参考：[【NLP】Transformer模型原理详解](https://zhuanlan.zhihu.com/p/44121378)\n",
    "\n",
    "\n",
    "### 模型定义\n",
    "\n",
    "训练nn.TransformerEncoder模型。 语言建模任务是为给定单词（或单词序列）遵循单词序列的可能性分配概率。 标记序列首先传递到嵌入层，然后传递到位置编码层以说明单词的顺序\n",
    "\n",
    "nn.TransformerEncoder由多层nn.TransformerEncoderLayer组成。 与输入序列一起，还需要一个正方形的注意掩码，因为nn.TransformerEncoder中的自注意层仅允许出现在该序列中的较早位置。 对于语言建模任务，应屏蔽将来头寸上的所有标记。 为了获得实际的单词，将nn.TransformerEncoder模型的输出发送到最终的Linear层，然后是对数 Softmax 函数。\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import math\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class PositionalEncoding(nn.Module):\n",
    "\n",
    "    def __init__(self, d_model, dropout=0.1, max_len=5000):\n",
    "        super(PositionalEncoding, self).__init__()\n",
    "        self.dropout = nn.Dropout(p=dropout)\n",
    "\n",
    "        pe = torch.zeros(max_len, d_model)\n",
    "        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)\n",
    "        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))\n",
    "        pe[:, 0::2] = torch.sin(position * div_term)\n",
    "        pe[:, 1::2] = torch.cos(position * div_term)\n",
    "        pe = pe.unsqueeze(0).transpose(0, 1)\n",
    "        self.register_buffer('pe', pe)\n",
    "\n",
    "    def forward(self, x):\n",
    "        x = x + self.pe[:x.size(0), :]\n",
    "        return self.dropout(x)\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "class TransformerModel(nn.Module):\n",
    "\n",
    "    def __init__(self, ntoken, ninp, nhead, nhid, nlayers, dropout=0.5):\n",
    "        super(TransformerModel, self).__init__()\n",
    "        from torch.nn import TransformerEncoder, TransformerEncoderLayer\n",
    "        self.model_type = 'Transformer'\n",
    "        self.pos_encoder = PositionalEncoding(ninp, dropout)\n",
    "        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)\n",
    "        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)\n",
    "        self.encoder = nn.Embedding(ntoken, ninp)\n",
    "        self.ninp = ninp\n",
    "        self.decoder = nn.Linear(ninp, ntoken)\n",
    "\n",
    "        self.init_weights()\n",
    "\n",
    "    def generate_square_subsequent_mask(self, sz):\n",
    "        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)\n",
    "        mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))\n",
    "        return mask\n",
    "\n",
    "    def init_weights(self):\n",
    "        initrange = 0.1\n",
    "        self.encoder.weight.data.uniform_(-initrange, initrange)\n",
    "        self.decoder.bias.data.zero_()\n",
    "        self.decoder.weight.data.uniform_(-initrange, initrange)\n",
    "\n",
    "    def forward(self, src, src_mask):\n",
    "        src = self.encoder(src) * math.sqrt(self.ninp)\n",
    "        src = self.pos_encoder(src)\n",
    "        output = self.transformer_encoder(src, src_mask)\n",
    "        output = self.decoder(output)\n",
    "        return output"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### 加载和批量数据\n",
    "本教程使用torchtext生成 Wikitext-2 数据集。\n",
    "vocab对象是基于训练数据集构建的，用于将标记数字化为张量。\n",
    "从序列数据开始，batchify()函数将数据集排列为列，以修剪掉数据分成大小为batch_size的批量后剩余的所有标记。\n",
    "\n",
    "例如，以字母为序列（总长度为 26）并且批大小为 4，我们将字母分为 4 个长度为 6 的序列：\n",
    "\n",
    "这些列被模型视为独立的，这意味着无法了解G和F的依赖性，但可以进行更有效的批量"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "outputs": [],
   "source": [
    "import io\n",
    "import torch\n",
    "from torchtext.utils import download_from_url, extract_archive\n",
    "from torchtext.data.utils import get_tokenizer\n",
    "from torchtext.vocab import build_vocab_from_iterator\n"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "outputs": [
    {
     "ename": "ConnectionError",
     "evalue": "HTTPSConnectionPool(host='s3.amazonaws.com', port=443): Max retries exceeded with url: /research.metamind.io/wikitext/wikitext-2-v1.zip (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x000001DD3737B640>: Failed to establish a new connection: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。'))",
     "output_type": "error",
     "traceback": [
      "\u001B[1;31m---------------------------------------------------------------------------\u001B[0m",
      "\u001B[1;31mTimeoutError\u001B[0m                              Traceback (most recent call last)",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connection.py\u001B[0m in \u001B[0;36m_new_conn\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    158\u001B[0m         \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 159\u001B[1;33m             conn = connection.create_connection(\n\u001B[0m\u001B[0;32m    160\u001B[0m                 \u001B[1;33m(\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_dns_host\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mport\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mtimeout\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mextra_kw\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\util\\connection.py\u001B[0m in \u001B[0;36mcreate_connection\u001B[1;34m(address, timeout, source_address, socket_options)\u001B[0m\n\u001B[0;32m     83\u001B[0m     \u001B[1;32mif\u001B[0m \u001B[0merr\u001B[0m \u001B[1;32mis\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 84\u001B[1;33m         \u001B[1;32mraise\u001B[0m \u001B[0merr\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     85\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\util\\connection.py\u001B[0m in \u001B[0;36mcreate_connection\u001B[1;34m(address, timeout, source_address, socket_options)\u001B[0m\n\u001B[0;32m     73\u001B[0m                 \u001B[0msock\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mbind\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msource_address\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 74\u001B[1;33m             \u001B[0msock\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconnect\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msa\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     75\u001B[0m             \u001B[1;32mreturn\u001B[0m \u001B[0msock\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mTimeoutError\u001B[0m: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001B[1;31mNewConnectionError\u001B[0m                        Traceback (most recent call last)",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connectionpool.py\u001B[0m in \u001B[0;36murlopen\u001B[1;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001B[0m\n\u001B[0;32m    669\u001B[0m             \u001B[1;31m# Make the request on the httplib connection object.\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 670\u001B[1;33m             httplib_response = self._make_request(\n\u001B[0m\u001B[0;32m    671\u001B[0m                 \u001B[0mconn\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connectionpool.py\u001B[0m in \u001B[0;36m_make_request\u001B[1;34m(self, conn, method, url, timeout, chunked, **httplib_request_kw)\u001B[0m\n\u001B[0;32m    380\u001B[0m         \u001B[1;32mtry\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 381\u001B[1;33m             \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_validate_conn\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mconn\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    382\u001B[0m         \u001B[1;32mexcept\u001B[0m \u001B[1;33m(\u001B[0m\u001B[0mSocketTimeout\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mBaseSSLError\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connectionpool.py\u001B[0m in \u001B[0;36m_validate_conn\u001B[1;34m(self, conn)\u001B[0m\n\u001B[0;32m    977\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[0mgetattr\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mconn\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"sock\"\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mNone\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m  \u001B[1;31m# AppEngine might not have  `.sock`\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 978\u001B[1;33m             \u001B[0mconn\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mconnect\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    979\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connection.py\u001B[0m in \u001B[0;36mconnect\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    308\u001B[0m         \u001B[1;31m# Add certificate verification\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 309\u001B[1;33m         \u001B[0mconn\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m_new_conn\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    310\u001B[0m         \u001B[0mhostname\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mhost\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connection.py\u001B[0m in \u001B[0;36m_new_conn\u001B[1;34m(self)\u001B[0m\n\u001B[0;32m    170\u001B[0m         \u001B[1;32mexcept\u001B[0m \u001B[0mSocketError\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 171\u001B[1;33m             raise NewConnectionError(\n\u001B[0m\u001B[0;32m    172\u001B[0m                 \u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m\"Failed to establish a new connection: %s\"\u001B[0m \u001B[1;33m%\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mNewConnectionError\u001B[0m: <urllib3.connection.HTTPSConnection object at 0x000001DD3737B640>: Failed to establish a new connection: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001B[1;31mMaxRetryError\u001B[0m                             Traceback (most recent call last)",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\adapters.py\u001B[0m in \u001B[0;36msend\u001B[1;34m(self, request, stream, timeout, verify, cert, proxies)\u001B[0m\n\u001B[0;32m    438\u001B[0m             \u001B[1;32mif\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[0mchunked\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 439\u001B[1;33m                 resp = conn.urlopen(\n\u001B[0m\u001B[0;32m    440\u001B[0m                     \u001B[0mmethod\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mrequest\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmethod\u001B[0m\u001B[1;33m,\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\connectionpool.py\u001B[0m in \u001B[0;36murlopen\u001B[1;34m(self, method, url, body, headers, retries, redirect, assert_same_host, timeout, pool_timeout, release_conn, chunked, body_pos, **response_kw)\u001B[0m\n\u001B[0;32m    725\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 726\u001B[1;33m             retries = retries.increment(\n\u001B[0m\u001B[0;32m    727\u001B[0m                 \u001B[0mmethod\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0murl\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0merror\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0me\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0m_pool\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mself\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0m_stacktrace\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0msys\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mexc_info\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m[\u001B[0m\u001B[1;36m2\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\urllib3\\util\\retry.py\u001B[0m in \u001B[0;36mincrement\u001B[1;34m(self, method, url, response, error, _pool, _stacktrace)\u001B[0m\n\u001B[0;32m    445\u001B[0m         \u001B[1;32mif\u001B[0m \u001B[0mnew_retry\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mis_exhausted\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 446\u001B[1;33m             \u001B[1;32mraise\u001B[0m \u001B[0mMaxRetryError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0m_pool\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0murl\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0merror\u001B[0m \u001B[1;32mor\u001B[0m \u001B[0mResponseError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mcause\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    447\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mMaxRetryError\u001B[0m: HTTPSConnectionPool(host='s3.amazonaws.com', port=443): Max retries exceeded with url: /research.metamind.io/wikitext/wikitext-2-v1.zip (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x000001DD3737B640>: Failed to establish a new connection: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。'))",
      "\nDuring handling of the above exception, another exception occurred:\n",
      "\u001B[1;31mConnectionError\u001B[0m                           Traceback (most recent call last)",
      "\u001B[1;32m~\\AppData\\Local\\Temp/ipykernel_16620/3517965082.py\u001B[0m in \u001B[0;36m<module>\u001B[1;34m\u001B[0m\n\u001B[0;32m      1\u001B[0m \u001B[1;32mfrom\u001B[0m \u001B[0mtorchtext\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mdatasets\u001B[0m \u001B[1;32mimport\u001B[0m \u001B[0mWikiText2\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m----> 2\u001B[1;33m \u001B[0mtrain_data\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mvalid_data\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtest_data\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mWikiText2\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m      3\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      4\u001B[0m \u001B[0murl\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;34m'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m      5\u001B[0m \u001B[0mtest_filepath\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mvalid_filepath\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mtrain_filepath\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mextract_archive\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mdownload_from_url\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0murl\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torchtext\\data\\datasets_utils.py\u001B[0m in \u001B[0;36mwrapper\u001B[1;34m(root, *args, **kwargs)\u001B[0m\n\u001B[0;32m    255\u001B[0m             \u001B[1;32mif\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[0mos\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mpath\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mexists\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnew_root\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    256\u001B[0m                 \u001B[0mos\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mmakedirs\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mnew_root\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 257\u001B[1;33m             \u001B[1;32mreturn\u001B[0m \u001B[0mfunc\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mroot\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mnew_root\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m*\u001B[0m\u001B[0margs\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    258\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    259\u001B[0m         \u001B[1;32mreturn\u001B[0m \u001B[0mwrapper\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torchtext\\data\\datasets_utils.py\u001B[0m in \u001B[0;36mnew_fn\u001B[1;34m(root, split, **kwargs)\u001B[0m\n\u001B[0;32m    217\u001B[0m         \u001B[0mresult\u001B[0m \u001B[1;33m=\u001B[0m \u001B[1;33m[\u001B[0m\u001B[1;33m]\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    218\u001B[0m         \u001B[1;32mfor\u001B[0m \u001B[0mitem\u001B[0m \u001B[1;32min\u001B[0m \u001B[0m_check_default_set\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msplit\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0msplits\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mfn\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0m__name__\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 219\u001B[1;33m             \u001B[0mresult\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mappend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mfn\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mroot\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mitem\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    220\u001B[0m         \u001B[1;32mreturn\u001B[0m \u001B[0m_wrap_datasets\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mtuple\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mresult\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0msplit\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    221\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torchtext\\datasets\\wikitext2.py\u001B[0m in \u001B[0;36mWikiText2\u001B[1;34m(root, split)\u001B[0m\n\u001B[0;32m     27\u001B[0m \u001B[1;33m@\u001B[0m\u001B[0m_wrap_split_argument\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m'train'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m'valid'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;34m'test'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     28\u001B[0m \u001B[1;32mdef\u001B[0m \u001B[0mWikiText2\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mroot\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0msplit\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 29\u001B[1;33m     \u001B[0mdataset_tar\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mdownload_from_url\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mURL\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mroot\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mroot\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhash_value\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mMD5\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mhash_type\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;34m'md5'\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     30\u001B[0m     \u001B[0mextracted_files\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mextract_archive\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mdataset_tar\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     31\u001B[0m     \u001B[0mpath\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0m_find_match\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msplit\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mextracted_files\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\torchtext\\utils.py\u001B[0m in \u001B[0;36mdownload_from_url\u001B[1;34m(url, path, root, overwrite, hash_value, hash_type)\u001B[0m\n\u001B[0;32m    118\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    119\u001B[0m     \u001B[1;32mif\u001B[0m \u001B[1;34m'drive.google.com'\u001B[0m \u001B[1;32mnot\u001B[0m \u001B[1;32min\u001B[0m \u001B[0murl\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 120\u001B[1;33m         \u001B[0mresponse\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mrequests\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mget\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0murl\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mheaders\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;33m{\u001B[0m\u001B[1;34m'User-Agent'\u001B[0m\u001B[1;33m:\u001B[0m \u001B[1;34m'Mozilla/5.0'\u001B[0m\u001B[1;33m}\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mstream\u001B[0m\u001B[1;33m=\u001B[0m\u001B[1;32mTrue\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    121\u001B[0m         \u001B[1;32mreturn\u001B[0m \u001B[0m_process_response\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mresponse\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mroot\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mfilename\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    122\u001B[0m     \u001B[1;32melse\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\api.py\u001B[0m in \u001B[0;36mget\u001B[1;34m(url, params, **kwargs)\u001B[0m\n\u001B[0;32m     74\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     75\u001B[0m     \u001B[0mkwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msetdefault\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m'allow_redirects'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;32mTrue\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 76\u001B[1;33m     \u001B[1;32mreturn\u001B[0m \u001B[0mrequest\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;34m'get'\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0murl\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mparams\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mparams\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     77\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     78\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\api.py\u001B[0m in \u001B[0;36mrequest\u001B[1;34m(method, url, **kwargs)\u001B[0m\n\u001B[0;32m     59\u001B[0m     \u001B[1;31m# cases, and look like a memory leak in others.\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     60\u001B[0m     \u001B[1;32mwith\u001B[0m \u001B[0msessions\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mSession\u001B[0m\u001B[1;33m(\u001B[0m\u001B[1;33m)\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0msession\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m---> 61\u001B[1;33m         \u001B[1;32mreturn\u001B[0m \u001B[0msession\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mrequest\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mmethod\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mmethod\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0murl\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0murl\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m     62\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m     63\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\sessions.py\u001B[0m in \u001B[0;36mrequest\u001B[1;34m(self, method, url, params, data, headers, cookies, files, auth, timeout, allow_redirects, proxies, hooks, stream, verify, cert, json)\u001B[0m\n\u001B[0;32m    528\u001B[0m         }\n\u001B[0;32m    529\u001B[0m         \u001B[0msend_kwargs\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0mupdate\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0msettings\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 530\u001B[1;33m         \u001B[0mresp\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0mself\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mprep\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0msend_kwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    531\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    532\u001B[0m         \u001B[1;32mreturn\u001B[0m \u001B[0mresp\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\sessions.py\u001B[0m in \u001B[0;36msend\u001B[1;34m(self, request, **kwargs)\u001B[0m\n\u001B[0;32m    641\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    642\u001B[0m         \u001B[1;31m# Send the request\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 643\u001B[1;33m         \u001B[0mr\u001B[0m \u001B[1;33m=\u001B[0m \u001B[0madapter\u001B[0m\u001B[1;33m.\u001B[0m\u001B[0msend\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0mrequest\u001B[0m\u001B[1;33m,\u001B[0m \u001B[1;33m**\u001B[0m\u001B[0mkwargs\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    644\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    645\u001B[0m         \u001B[1;31m# Total elapsed time of the request (approximately)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;32mc:\\users\\dengxixi\\miniconda3\\envs\\torch-ocr\\lib\\site-packages\\requests\\adapters.py\u001B[0m in \u001B[0;36msend\u001B[1;34m(self, request, stream, timeout, verify, cert, proxies)\u001B[0m\n\u001B[0;32m    514\u001B[0m                 \u001B[1;32mraise\u001B[0m \u001B[0mSSLError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0me\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mrequest\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mrequest\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    515\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[1;32m--> 516\u001B[1;33m             \u001B[1;32mraise\u001B[0m \u001B[0mConnectionError\u001B[0m\u001B[1;33m(\u001B[0m\u001B[0me\u001B[0m\u001B[1;33m,\u001B[0m \u001B[0mrequest\u001B[0m\u001B[1;33m=\u001B[0m\u001B[0mrequest\u001B[0m\u001B[1;33m)\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0m\u001B[0;32m    517\u001B[0m \u001B[1;33m\u001B[0m\u001B[0m\n\u001B[0;32m    518\u001B[0m         \u001B[1;32mexcept\u001B[0m \u001B[0mClosedPoolError\u001B[0m \u001B[1;32mas\u001B[0m \u001B[0me\u001B[0m\u001B[1;33m:\u001B[0m\u001B[1;33m\u001B[0m\u001B[1;33m\u001B[0m\u001B[0m\n",
      "\u001B[1;31mConnectionError\u001B[0m: HTTPSConnectionPool(host='s3.amazonaws.com', port=443): Max retries exceeded with url: /research.metamind.io/wikitext/wikitext-2-v1.zip (Caused by NewConnectionError('<urllib3.connection.HTTPSConnection object at 0x000001DD3737B640>: Failed to establish a new connection: [WinError 10060] 由于连接方在一段时间后没有正确答复或连接的主机没有反应，连接尝试失败。'))"
     ]
    }
   ],
   "source": [
    "from torchtext.datasets import WikiText2\n",
    "train_data, valid_data, test_data = WikiText2()\n",
    "\n",
    "# url = 'https://s3.amazonaws.com/research.metamind.io/wikitext/wikitext-2-v1.zip'\n",
    "# test_filepath, valid_filepath, train_filepath = extract_archive(download_from_url(url))\n",
    "tokenizer = get_tokenizer('basic_english')\n",
    "def yield_tokens(data_iter):\n",
    "    for _, text in data_iter:\n",
    "        yield tokenizer(text)\n",
    "\n",
    "vocab = build_vocab_from_iterator(yield_tokens(train_data), specials=[\"<unk>\"])\n",
    "# vocab = build_vocab_from_iterator(map(tokenizer, iter(io.open(train_filepath, encoding=\"utf8\"))))\n",
    "\n",
    "print('vocab size:', len(vocab))"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "def data_process(raw_text_iter):\n",
    "  data = [torch.tensor([vocab[token] for token in tokenizer(item)],\n",
    "                       dtype=torch.long) for item in raw_text_iter]\n",
    "  return torch.cat(tuple(filter(lambda t: t.numel() > 0, data)))\n",
    "\n",
    "train_data = data_process(iter(io.open(train_filepath, encoding=\"utf8\")))\n",
    "val_data = data_process(iter(io.open(valid_filepath, encoding=\"utf8\")))\n",
    "test_data = data_process(iter(io.open(test_filepath, encoding=\"utf8\")))\n",
    "\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "def batchify(data, bsz):\n",
    "    # Divide the dataset into bsz parts.\n",
    "    nbatch = data.size(0) // bsz\n",
    "    # Trim off any extra elements that wouldn't cleanly fit (remainders).\n",
    "    data = data.narrow(0, 0, nbatch * bsz)\n",
    "    # Evenly divide the data across the bsz batches.\n",
    "    data = data.view(bsz, -1).t().contiguous()\n",
    "    return data.to(device)\n",
    "\n",
    "batch_size = 20\n",
    "eval_batch_size = 10\n",
    "train_data = batchify(train_data, batch_size)\n",
    "val_data = batchify(val_data, eval_batch_size)\n",
    "test_data = batchify(test_data, eval_batch_size)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
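  {
   "cell_type": "markdown",
   "source": [
    "The column layout that `batchify()` produces is easier to see on a toy input. The cell below is a minimal sketch and not part of the original tutorial: `toy_batchify`, `toy_stream` and `toy_columns` are hypothetical names, and the function repeats the reshaping steps above without the move to `device`. Each column holds a contiguous slice of the stream, so the model treats the columns as independent sequences."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def toy_batchify(data, bsz):\n",
    "    # Same reshaping as batchify(), kept on the CPU for illustration.\n",
    "    nbatch = data.size(0) // bsz\n",
    "    data = data.narrow(0, 0, nbatch * bsz)\n",
    "    return data.view(bsz, -1).t().contiguous()\n",
    "\n",
    "toy_stream = torch.arange(26)  # stand-in for a stream of 26 token ids\n",
    "toy_columns = toy_batchify(toy_stream, 4)\n",
    "print(toy_columns.shape)  # torch.Size([6, 4]); the last 2 ids are trimmed\n",
    "print(toy_columns[:, 0])  # tensor([0, 1, 2, 3, 4, 5])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },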
  {
   "cell_type": "markdown",
   "source": [
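    "### Functions to generate input and target sequences\n",
    "The `get_batch()` function generates the input and target sequences for the Transformer model. It splits the source data into chunks of length `bptt`. For the language modeling task, the model needs the words that follow the input as its target. For example, with a `bptt` value of 2, we get the following two variables at `i = 0`:\n",
    "\n",
    "![Input and target sequences for bptt = 2](../_img/transformer_input_target.png)\n",
    "\n",
    "Note that the chunks run along dimension 0, consistent with the `S` dimension in the Transformer model. The batch dimension `N` runs along dimension 1."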
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "bptt = 35\n",
    "def get_batch(source, i):\n",
    "    seq_len = min(bptt, len(source) - 1 - i)\n",
    "    data = source[i:i+seq_len]\n",
    "    target = source[i+1:i+1+seq_len].reshape(-1)\n",
    "    return data, target"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
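  {
   "cell_type": "markdown",
   "source": [
    "To make the input/target offset concrete, the cell below replays the same slicing on a tiny tensor. It is a sketch under stated assumptions: `toy_get_batch` and `toy_source` are hypothetical names introduced for illustration, and `bptt` is passed as an argument so the global value is left untouched. The target is the input shifted by one time step and flattened."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "def toy_get_batch(source, i, bptt):\n",
    "    # Same slicing as get_batch(), with bptt passed in explicitly.\n",
    "    seq_len = min(bptt, len(source) - 1 - i)\n",
    "    data = source[i:i + seq_len]\n",
    "    target = source[i + 1:i + 1 + seq_len].reshape(-1)\n",
    "    return data, target\n",
    "\n",
    "# 6 time steps (dimension 0) x 2 batch columns (dimension 1)\n",
    "toy_source = torch.arange(12).view(2, 6).t().contiguous()\n",
    "toy_data, toy_target = toy_get_batch(toy_source, 0, bptt=2)\n",
    "print(toy_data)    # tensor([[0, 6], [1, 7]])\n",
    "print(toy_target)  # tensor([1, 7, 2, 8])"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },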
  {
   "cell_type": "markdown",
   "source": [
    "### Initiate an instance\n",
    "The model is built with the hyperparameters below. The vocabulary size equals the length of the `vocab` object."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "ntokens = len(vocab) # the size of vocabulary\n",
    "emsize = 200 # embedding dimension\n",
    "nhid = 200 # the dimension of the feedforward network model in nn.TransformerEncoder\n",
    "nlayers = 2 # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder\n",
    "nhead = 2 # the number of heads in the multiheadattention models\n",
    "dropout = 0.2 # the dropout value\n",
    "model = TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
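    "### Run the model\n",
    "`CrossEntropyLoss` is used to track the loss, and `SGD` implements stochastic gradient descent as the optimizer.\n",
    "The initial learning rate is set to 5.0, and `StepLR` adjusts it from epoch to epoch.\n",
    "During training, `nn.utils.clip_grad_norm_` rescales all gradients together to keep them from exploding."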
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "criterion = nn.CrossEntropyLoss()\n",
    "lr = 5.0 # learning rate\n",
    "optimizer = torch.optim.SGD(model.parameters(), lr=lr)\n",
    "scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.95)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
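  {
   "cell_type": "markdown",
   "source": [
    "The effect of the scheduler and of gradient clipping can be checked in isolation. The cell below is a minimal sketch on a single dummy parameter: all `demo_*` names are hypothetical and exist only for illustration, so the real `optimizer` and `scheduler` are not modified. With `gamma=0.95` the learning rate shrinks by 5% per epoch (roughly 5.0, 4.75, 4.51), and `clip_grad_norm_` returns the gradient norm measured before clipping."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "demo_param = torch.nn.Parameter(torch.zeros(1))\n",
    "demo_opt = torch.optim.SGD([demo_param], lr=5.0)\n",
    "demo_sched = torch.optim.lr_scheduler.StepLR(demo_opt, step_size=1, gamma=0.95)\n",
    "\n",
    "for demo_epoch in range(1, 4):\n",
    "    print(demo_epoch, demo_sched.get_last_lr()[0])\n",
    "    demo_opt.step()    # no gradient yet, so the parameter is left unchanged\n",
    "    demo_sched.step()  # multiply the learning rate by gamma\n",
    "\n",
    "# Give the parameter a gradient of norm 3.0, then clip it to a maximum norm of 0.5.\n",
    "demo_param.grad = torch.tensor([3.0])\n",
    "demo_norm = torch.nn.utils.clip_grad_norm_([demo_param], max_norm=0.5)\n",
    "print(demo_norm.item())        # norm before clipping: 3.0\n",
    "print(demo_param.grad.item())  # rescaled to approximately 0.5"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },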
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import time\n",
    "def train():\n",
    "    model.train() # Turn on the train mode\n",
    "    total_loss = 0.\n",
    "    start_time = time.time()\n",
    "    src_mask = model.generate_square_subsequent_mask(bptt).to(device)\n",
    "    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):\n",
    "        data, targets = get_batch(train_data, i)\n",
    "        optimizer.zero_grad()\n",
    "        if data.size(0) != bptt:\n",
    "            src_mask = model.generate_square_subsequent_mask(data.size(0)).to(device)\n",
    "        output = model(data, src_mask)\n",
    "        loss = criterion(output.view(-1, ntokens), targets)\n",
    "        loss.backward()\n",
    "        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)\n",
    "        optimizer.step()\n",
    "\n",
    "        total_loss += loss.item()\n",
    "        log_interval = 200\n",
    "        if batch % log_interval == 0 and batch > 0:\n",
    "            cur_loss = total_loss / log_interval\n",
    "            elapsed = time.time() - start_time\n",
    "            print('| epoch {:3d} | {:5d}/{:5d} batches | '\n",
    "                  'lr {:02.2f} | ms/batch {:5.2f} | '\n",
    "                  'loss {:5.2f} | ppl {:8.2f}'.format(\n",
    "                    epoch, batch, len(train_data) // bptt, scheduler.get_last_lr()[0],\n",
    "                    elapsed * 1000 / log_interval,\n",
    "                    cur_loss, math.exp(cur_loss)))\n",
    "            total_loss = 0\n",
    "            start_time = time.time()\n",
    "\n",
    "def evaluate(eval_model, data_source):\n",
    "    eval_model.eval() # Turn on the evaluation mode\n",
    "    total_loss = 0.\n",
    "    src_mask = eval_model.generate_square_subsequent_mask(bptt).to(device)\n",
    "    with torch.no_grad():\n",
    "        for i in range(0, data_source.size(0) - 1, bptt):\n",
    "            data, targets = get_batch(data_source, i)\n",
    "            if data.size(0) != bptt:\n",
    "                src_mask = eval_model.generate_square_subsequent_mask(data.size(0)).to(device)\n",
    "            output = eval_model(data, src_mask)\n",
    "            output_flat = output.view(-1, ntokens)\n",
    "            total_loss += len(data) * criterion(output_flat, targets).item()\n",
    "    return total_loss / (len(data_source) - 1)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "Loop over the epochs. Save the model whenever the validation loss is the best seen so far, and adjust the learning rate after each epoch."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "import copy\n",
    "\n",
    "best_val_loss = float(\"inf\")\n",
    "epochs = 3 # The number of epochs\n",
    "best_model = None\n",
    "\n",
    "for epoch in range(1, epochs + 1):\n",
    "    epoch_start_time = time.time()\n",
    "    train()\n",
    "    val_loss = evaluate(model, val_data)\n",
    "    print('-' * 89)\n",
    "    print('| end of epoch {:3d} | time: {:5.2f}s | valid loss {:5.2f} | '\n",
    "          'valid ppl {:8.2f}'.format(epoch, (time.time() - epoch_start_time),\n",
    "                                     val_loss, math.exp(val_loss)))\n",
    "    print('-' * 89)\n",
    "\n",
    "    if val_loss < best_val_loss:\n",
    "        best_val_loss = val_loss\n",
    "        # Snapshot the weights; a plain assignment would only alias the live model.\n",
    "        best_model = copy.deepcopy(model)\n",
    "\n",
    "    scheduler.step()"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  },
  {
   "cell_type": "markdown",
   "source": [
    "### Evaluate the model on the test dataset\n",
    "Apply the best model to check the results on the test dataset."
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%% md\n"
    }
   }
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "outputs": [],
   "source": [
    "test_loss = evaluate(best_model, test_data)\n",
    "print('=' * 89)\n",
    "print('| End of training | test loss {:5.2f} | test ppl {:8.2f}'.format(\n",
    "    test_loss, math.exp(test_loss)))\n",
    "print('=' * 89)"
   ],
   "metadata": {
    "collapsed": false,
    "pycharm": {
     "name": "#%%\n"
    }
   }
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 2
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython2",
   "version": "2.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}